{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import time, shutil\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.sparse\n",
    "import tensorflow as tf\n",
    "\n",
    "from lib import models, graph, coarsening, utils\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "flags = tf.app.flags\n",
    "FLAGS = flags.FLAGS\n",
    "\n",
    "# Graphs.\n",
    "flags.DEFINE_integer('number_edges', 16, 'Graph: minimum number of edges per vertex.')\n",
    "flags.DEFINE_string('metric', 'cosine', 'Graph: similarity measure (between features).')\n",
    "# TODO: change cgcnn for combinatorial Laplacians.\n",
    "flags.DEFINE_bool('normalized_laplacian', True, 'Graph Laplacian: normalized.')\n",
    "flags.DEFINE_integer('coarsening_levels', 0, 'Number of coarsened graphs.')\n",
    "\n",
    "flags.DEFINE_string('dir_data', os.path.join('data', 'rcv1'), 'Directory to store data.')\n",
    "flags.DEFINE_integer('val_size', 400, 'Size of the validation set.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data\n",
    "\n",
    "**From Dropout (Bruna did the same).**\n",
    "We took the dataset and split it into 63 classes based on the 63 categories at the second level of the category tree. We removed 11 categories that did not have any data and one category that had only 4 training examples. We also removed one category that covered a huge chunk (25%) of the examples. This left us with 50 classes and 402,738 documents. We divided the documents into equal-sized training and test sets randomly. Each document was represented\n",
    "using the 2000 most frequent non-stopwords in the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Fetch dataset from Scikit-learn.\n",
    "dataset = utils.TextRCV1(data_home=FLAGS.dir_data)\n",
    "\n",
    "# Pre-processing: transform everything to a-z and whitespace.\n",
    "#print(train.show_document(1)[:400])\n",
    "#train.clean_text(num='substitute')\n",
    "\n",
    "# Analyzing / tokenizing: transform documents to bags-of-words.\n",
    "#stop_words = set(sklearn.feature_extraction.text.ENGLISH_STOP_WORDS)\n",
    "# Or stop words from NLTK.\n",
    "# Add e.g. don, ve.\n",
    "#train.vectorize(stop_words='english')\n",
    "#print(train.show_document(1)[:400])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Selection of classes.\n",
    "keep = ['C11','C12','C13','C14','C15','C16','C17','C18','C21','C22','C23','C24',\n",
    "        'C31','C32','C33','C34','C41','C42','E11','E12','E13','E14','E21','E31',\n",
    "        'E41','E51','E61','E71','G15','GCRIM','GDEF','GDIP','GDIS','GENT','GENV',\n",
    "        'GFAS','GHEA','GJOB','GMIL','GOBIT','GODD','GPOL','GPRO','GREL','GSCI',\n",
    "        'GSPO','GTOUR','GVIO','GVOTE','GWEA','GWELF','M11','M12','M13','M14']\n",
    "assert len(keep) == 55  # There is 55 second-level categories according to LYRL2004.\n",
    "keep.remove('C15')   # 151785 documents\n",
    "keep.remove('GMIL')  # 5 documents only\n",
    "\n",
    "dataset.show_doc_per_class()\n",
    "dataset.show_classes_per_doc()\n",
    "dataset.remove_classes(keep)\n",
    "dataset.show_doc_per_class(True)\n",
    "dataset.show_classes_per_doc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Remove documents with multiple classes.\n",
    "dataset.select_documents()\n",
    "dataset.data_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Remove short documents.\n",
    "#train.data_info(True)\n",
    "#wc = train.remove_short_documents(nwords=20, vocab='full')\n",
    "#train.data_info()\n",
    "#print('shortest: {}, longest: {} words'.format(wc.min(), wc.max()))\n",
    "#plt.figure(figsize=(17,5))\n",
    "#plt.semilogy(wc, '.');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Feature selection.\n",
    "# Other options include: mutual information or document count.\n",
    "#freq = train.keep_top_words(1000, 20)\n",
    "#train.data_info()\n",
    "#train.show_document(1)\n",
    "#plt.figure(figsize=(17,5))\n",
    "#plt.semilogy(freq);\n",
    "\n",
    "# Remove documents whose signal would be the zero vector.\n",
    "#wc = train.remove_short_documents(nwords=5, vocab='selected')\n",
    "#train.data_info(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#dataset.normalize(norm='l1')\n",
    "dataset.show_document(1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Word embedding\n",
    "#if True:\n",
    "#    train.embed()\n",
    "#else:\n",
    "#    train.embed('data_word2vec/GoogleNews-vectors-negative300.bin')\n",
    "#train.data_info()\n",
    "# Further feature selection. (TODO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Seeded random permutation of the document indices, so the split is reproducible.\n",
    "perm = np.random.RandomState(seed=42).permutation(dataset.data.shape[0])\n",
    "# The first half of the permutation is the test set, the remainder the training set.\n",
    "Ntest = dataset.data.shape[0] // 2\n",
    "perm_test = perm[:Ntest]\n",
    "perm_train = perm[Ntest:]\n",
    "train_data = dataset.data[perm_train,:].astype(np.float32)\n",
    "test_data = dataset.data[perm_test,:].astype(np.float32)\n",
    "train_labels = dataset.labels[perm_train]\n",
    "test_labels = dataset.labels[perm_test]\n",
    "\n",
    "# Signals used to build the feature graph.\n",
    "if False:\n",
    "    # NOTE(review): presumably requires the (currently commented-out) embedding step\n",
    "    # to have populated `dataset.embeddings` -- confirm before enabling this branch.\n",
    "    graph_data = dataset.embeddings.astype(np.float32)\n",
    "else:\n",
    "    # Transpose of the full data matrix (training and test documents together).\n",
    "    graph_data = dataset.data.T.astype(np.float32)\n",
    "\n",
    "#del dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "t_start = time.process_time()\n",
    "dist, idx = graph.distance_lshforest(graph_data.astype(np.float64), k=FLAGS.number_edges, metric=FLAGS.metric)\n",
    "A = graph.adjacency(dist.astype(np.float32), idx)\n",
    "print(\"{} > {} edges\".format(A.nnz//2, FLAGS.number_edges*graph_data.shape[0]//2))\n",
    "A = graph.replace_random_edges(A, 0)\n",
    "graphs, perm = coarsening.coarsen(A, levels=FLAGS.coarsening_levels, self_connections=False)\n",
    "L = [graph.laplacian(A, normalized=True) for A in graphs]\n",
    "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n",
    "#graph.plot_spectrum(L)\n",
    "#del graph_data, A, dist, idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# The data permutation below is commented out; guard that no coarsening was requested.\n",
    "# Compare by value: `is 0` tests object identity and only works by implementation accident.\n",
    "assert FLAGS.coarsening_levels == 0\n",
    "#t_start = time.process_time()\n",
    "#train_data = scipy.sparse.csr_matrix(coarsening.perm_data(train_data.toarray(), perm))\n",
    "#test_data = scipy.sparse.csr_matrix(coarsening.perm_data(test_data.toarray(), perm))\n",
    "#print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n",
    "#del perm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Training set is shuffled already.\n",
    "#perm = np.random.permutation(train_data.shape[0])\n",
    "#train_data = train_data[perm,:]\n",
    "#train_labels = train_labels[perm]\n",
    "\n",
    "# Validation set.\n",
    "if False:\n",
    "    val_data = train_data[:FLAGS.val_size,:]\n",
    "    val_labels = train_labels[:FLAGS.val_size]\n",
    "    train_data = train_data[FLAGS.val_size:,:]\n",
    "    train_labels = train_labels[FLAGS.val_size:]\n",
    "else:\n",
    "    val_data = test_data\n",
    "    val_labels = test_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if False:\n",
    "    utils.baseline(train_data, train_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "common = {}\n",
    "common['dir_name']       = 'rcv1/'\n",
    "common['num_epochs']     = 4\n",
    "common['batch_size']     = 100\n",
    "common['decay_steps']    = len(train_labels) / common['batch_size']\n",
    "common['eval_frequency'] = 200\n",
    "common['filter']         = 'chebyshev5'\n",
    "common['brelu']          = 'b1relu'\n",
    "common['pool']           = 'mpool1'\n",
    "C = max(train_labels) + 1  # number of classes\n",
    "\n",
    "model_perf = utils.model_perf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    params['regularization'] = 0\n",
    "    params['dropout']        = 1\n",
    "    params['learning_rate']  = 1e3\n",
    "    params['decay_rate']     = 0.95\n",
    "    params['momentum']       = 0.9\n",
    "    params['F']              = []\n",
    "    params['K']              = []\n",
    "    params['p']              = []\n",
    "    params['M']              = [C]\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'fc_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    params['regularization'] = 0\n",
    "    params['dropout']        = 1\n",
    "    params['learning_rate']  = 0.1\n",
    "    params['decay_rate']     = 0.95\n",
    "    params['momentum']       = 0.9\n",
    "    params['F']              = []\n",
    "    params['K']              = []\n",
    "    params['p']              = []\n",
    "    params['M']              = [2500, C]\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'fc_fc_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    params['regularization'] = 0\n",
    "    params['dropout']        = 1\n",
    "    params['learning_rate']  = 0.1\n",
    "    params['decay_rate']     = 0.95\n",
    "    params['momentum']       = 0.9\n",
    "    params['F']              = []\n",
    "    params['K']              = []\n",
    "    params['p']              = []\n",
    "    params['M']              = [2500, 500, C]\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'cgconv_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    params['regularization'] = 1e-3\n",
    "    params['dropout']        = 1\n",
    "    params['learning_rate']  = 0.1\n",
    "    params['decay_rate']     = 0.999\n",
    "    params['momentum']       = 0\n",
    "    params['F']              = [1]\n",
    "    params['K']              = [5]\n",
    "    params['p']              = [1]\n",
    "    params['M']              = [C]\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'cgconv_fc_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    params['regularization'] = 0\n",
    "    params['dropout']        = 1\n",
    "    params['learning_rate']  = 0.1\n",
    "    params['decay_rate']     = 0.999\n",
    "    params['momentum']       = 0\n",
    "    params['F']              = [5]\n",
    "    params['K']              = [15]\n",
    "    params['p']              = [1]\n",
    "    params['M']              = [100, C]\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model_perf.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
